In [17]:
import os

import numpy as np
import cv2
import h5py

from PIL import Image

import IPython

import sys
sys.path.insert(0, '/nfs/kun2/users/riadoshi/')
import mink
import mujoco
import numpy as np
from typing import Dict
import mediapy


In [18]:
# %matplotlib ipympl
import matplotlib.pyplot as plt

In [19]:
e = IPython.embed
DT = 0.02

JOINT_NAMES = [
    "waist",
    "shoulder",
    "elbow",
    "forearm_roll",
    "wrist_angle",
    "wrist_rotate",
]
STATE_NAMES = JOINT_NAMES + ["gripper"]

In [20]:
def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + ".hdf5")
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        is_sim = root.attrs["sim"]
        compressed = root.attrs.get("compress", False)
        qpos = root["/observations/qpos"][()]
        qvel = root["/observations/qvel"][()]
        effort = root["/observations/effort"][()]
        action = root["/action"][()]
        image_dict = dict()
        for cam_name in root[f"/observations/images/"].keys():
            # skip depth images
            if "_depth" in cam_name:
                continue
            image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()]
        if compressed:
            compress_len = root["/compress_len"][()]

    if compressed:
        for cam_id, cam_name in enumerate(image_dict.keys()):
            # un-pad and uncompress
            padded_compressed_image_list = image_dict[cam_name]
            image_list = []
            for frame_id, padded_compressed_image in enumerate(
                padded_compressed_image_list
            ):  # [:1000] to save memory
                image_len = int(compress_len[cam_id, frame_id])
                compressed_image = padded_compressed_image
                image = cv2.imdecode(compressed_image, 1)
                image_list.append(image)
            image_dict[cam_name] = image_list

    return qpos, qvel, effort, action, image_dict


In [21]:
DATASET_NAME = 'aloha_sushi_cut_full_dataset'
data = np.load(f'saved_trajs/{DATASET_NAME}/traj_data.npz', allow_pickle=True)['arr_0'].item()
qpos_right, image_dict_right = data['proprios'], data['images']
qpos_left, image_dict_left = qpos_right, image_dict_right
# dataset_dir = "/nfs/kun2/users/riadoshi/universal-CoT/data_generation/aloha_keypoints/test_trajs"
# qpos_right, qvel_right, effort_right, action_right, image_dict_right = load_hdf5(dataset_dir, 'episode_12')
# qpos_left, qvel_left, effort_left, action_left, image_dict_left = load_hdf5(dataset_dir, 'episode_10')


In [35]:
mediapy.show_video(image_dict_right, fps=2)

0
This browser does not support the video tag.


In [22]:
model = mujoco.MjModel.from_xml_path("/nfs/kun2/users/riadoshi/mujoco_menagerie/aloha/scene.xml")
configuration = mink.Configuration(model)
data = mujoco.MjData(model)

In [23]:
def get_ee_pos(
    configuration: mink.Configuration,
    q: np.ndarray,
) -> Dict[str, np.ndarray]:
    configuration.update(q=q)  # Make sure same ordering as in mujoco model.
    return {
        "left/gripper": configuration.get_transform_frame_to_world("left/gripper", "site"),
        "right/gripper": configuration.get_transform_frame_to_world("right/gripper", "site"),
    }

In [24]:
poses_arr_left = []
for i in range(len(qpos_left)):
    joint_pos = np.zeros([16])
    joint_pos[:7] = qpos_left[i][:7]
    joint_pos[8:-1] = qpos_left[i][7:]
    poses = get_ee_pos(configuration, joint_pos)
    poses_arr_left.append(poses["left/gripper"].translation())

In [25]:
poses_arr_right = []
for i in range(len(qpos_right)):
    joint_pos = np.zeros([16])
    joint_pos[:7] = qpos_right[i][:7]
    joint_pos[8:-1] = qpos_right[i][7:]
    poses = get_ee_pos(configuration, joint_pos)

    poses_arr_right.append(poses["right/gripper"].translation())

In [26]:
right_coords = np.array([(132, 195), (132, 195), (145, 158), (133, 156), (180, 174), (133, 169), (118, 163), (116, 168), (121, 170), (147, 184), (136, 196), (149, 177), (150, 181), (150, 183), (149, 186), (150, 187), (147, 181), (147, 183), (147, 189), (149, 181), (149, 189), (148, 196), (177, 186), (160, 179), (142, 179), (157, 155), (148, 180), (142, 189), (160, 155), (148, 185), (140, 197), (165, 161), (146, 197), (139, 207), (140, 205), (170, 153), (167, 154), (168, 170), (131, 165), (111, 170), (112, 172), (118, 171), (140, 177), (129, 188), (133, 191), (138, 170), (146, 138), (146, 139), (146, 128), (144, 139), (141, 145), (137, 148), (142, 141), (155, 160), (145, 157), (144, 159), (158, 178), (154, 184), (153, 185), (155, 182), (155, 187), (151, 187), (155, 187), (153, 190), (153, 192), (181, 173), (170, 158), (162, 153), (160, 154), (149, 183), (137, 198), (178, 173), (148, 175), (142, 180), (154, 151), (150, 179), (141, 188), (159, 148), (161, 159), (141, 190), (162, 160), (164, 150), (149, 189), (145, 190), (166, 154), (164, 153), (175, 180), (154, 175), (145, 176), (157, 157), (148, 174), (140, 181), (164, 161), (151, 177), (140, 192), (140, 192), (139, 182), (150, 176), (169, 158), (139, 196), (144, 187), (165, 154)])
left_coords = np.array([(86, 171), (86, 171), (84, 165), (84, 165), (76, 168), (75, 168), (74, 169), (74, 170), (74, 170), (74, 169), (76, 167), (102, 177), (103, 179), (103, 180), (103, 184), (106, 184), (107, 179), (102, 185), (103, 189), (106, 183), (104, 192), (106, 198), (111, 177), (109, 175), (107, 175), (106, 175), (107, 175), (107, 175), (107, 175), (107, 173), (107, 173), (107, 173), (107, 173), (107, 173), (107, 173), (107, 173), (107, 173), (76, 166), (74, 169), (75, 170), (80, 171), (80, 170), (84, 166), (83, 166), (81, 161), (79, 156), (76, 154), (76, 155), (94, 155), (103, 155), (103, 168), (96, 165), (95, 168), (95, 167), (95, 166), (98, 164), (111, 183), (114, 184), (116, 184), (116, 182), (114, 189), (115, 188), (114, 189), (113, 192), (112, 194), (75, 165), (89, 165), (88, 165), (88, 165), (88, 165), (86, 164), (112, 175), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (103, 173), (104, 173), (104, 173), (104, 173), (116, 178), (108, 177), (108, 177), (106, 176), (106, 176), (105, 173), (106, 176), (107, 174), (107, 175), (107, 175), (107, 175), (107, 175), (107, 175), (107, 175), (107, 175), (107, 175)])

In [30]:
print(len(left_coords), len(right_coords))

102 102


In [27]:
subsample_poses_arr_left = np.array([x for x in poses_arr_left[::1]])
hg_poses_arr_left = np.concatenate([subsample_poses_arr_left, np.ones([len(subsample_poses_arr_left), 1])], axis=1)
left_transform, res, rank, sv = np.linalg.lstsq(hg_poses_arr_left, left_coords)
print("residual : ", res)

residual :  [5648.43931469 1120.53720297]


  left_transform, res, rank, sv = np.linalg.lstsq(hg_poses_arr_left, left_coords)


In [28]:
subsample_poses_arr_right = np.array([x for x in poses_arr_right[::1]])
hg_poses_arr_right = np.concatenate([subsample_poses_arr_right, np.ones([len(subsample_poses_arr_right), 1])], axis=1)
right_transform, res, rank, sv = np.linalg.lstsq(hg_poses_arr_right, right_coords)
print("residual : ", res)


residual :  [12758.65953939  4325.44232142]


  right_transform, res, rank, sv = np.linalg.lstsq(hg_poses_arr_right, right_coords)


In [14]:
import numpy as np
from PIL import Image

def resize_image(image_np, size, method=Image.LANCZOS):
    # Convert NumPy array to PIL Image
    image = Image.fromarray(image_np)
    
    image_resized = image.resize(size, resample=method)
    image_resized_np = np.array(image_resized)
    image_resized_np = np.clip(np.round(image_resized_np), 0, 255).astype(np.uint8)
    
    return image_resized_np

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2  # For image manipulation

# Assuming traj, configuration, left_transform, and right_transform are defined
# images, joint_poses = traj['observation']['image_primary'].squeeze(), traj['observation']['proprio_bimanual'].squeeze()

processed_images = []  # List to store processed images
# right_images = [resize_image(img, (256, 256)) for img in image_dict_right['cam_high']]
# left_images = [resize_image(img, (256, 256)) for img in image_dict_left['cam_high']]

# since we're visualizing on trajectories from the resized_256_256 dataset, no need to call resize
for idx, (img, pose) in enumerate(zip(image_dict_right, qpos_left)):
    joint_pos = np.zeros([16])
    joint_pos[:7] = pose[:7]
    joint_pos[8:-1] = pose[7:]
    poses = get_ee_pos(configuration, joint_pos)

    left_pos = poses["left/gripper"].translation().reshape(1, -1)
    hg_left_pos = np.concatenate([left_pos, np.ones([1, 1])], axis=1)
    left_pix_pos = (hg_left_pos @ left_transform).reshape(-1)
    left_x, left_y = int(left_pix_pos[0]), int(left_pix_pos[1])

    right_pos = poses["right/gripper"].translation().reshape(1, -1)
    hg_right_pos = np.concatenate([right_pos, np.ones([1, 1])], axis=1)
    right_pix_pos = (hg_right_pos @ right_transform).reshape(-1)
    right_x, right_y = int(right_pix_pos[0]), int(right_pix_pos[1])

    # Convert image to RGB if needed
    img = img if img.shape[-1] == 3 else np.repeat(img[:, :, np.newaxis], 3, axis=-1)

    # Plot points on image
    img_with_points = img.copy()
    cv2.circle(img_with_points, (left_x, left_y), 5, (0, 255, 0), -1)  # Green dot for left gripper
    cv2.circle(img_with_points, (right_x, right_y), 5, (255, 0, 0), -1)  # Blue dot for right gripper

    # Append to list
    processed_images.append(img_with_points)

# Convert to NumPy array and save
processed_images_np = np.array(processed_images)
mediapy.show_video(processed_images_np)


0
This browser does not support the video tag.


In [16]:
mediapy.show_video(processed_images_np, fps=2)


0
This browser does not support the video tag.


In [31]:
right_transform

array([[ 27.67359836,  -8.78031894],
       [ -2.15069723, -40.43744125],
       [ 18.22327217, -14.00618436],
       [138.67154193, 175.21244088]])

In [32]:
left_transform

array([[ 41.7059354 ,  17.0714593 ],
       [-14.98875831, -24.44365958],
       [ -1.10079113, -11.94109217],
       [103.08875805, 179.10297161]])

In [33]:
np.save('transforms/saved_left_arm_transform.npy', left_transform)
np.save('transforms/saved_right_arm_transform.npy', right_transform)