# GR00T N1.5 deployment details
### The official examples use GR1, so few details are available for DROID.
### We also do not know how they trained the model on DROID.
- Image size: (1, 256, 256, 3), uint8; resize with zero padding (letterboxing), as in pi0.
- What do exterior_image_1 and exterior_image_2 mean, i.e. which is left and which is right? For DROID, exterior 1 is the left camera and exterior 2 is the right camera.
- Have they already binarized the gripper action in their code? action = 1 means close, action = 0 means open.
- Rotation is represented as roll-pitch-yaw (RPY) Euler angles.

In [None]:
import os
import torch
import gr00t

from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.model.policy import Gr00tPolicy

In [None]:
# --- Paths and identifiers: edit these for your setup ---

# Hugging Face model id (or local checkpoint directory) of the policy weights.
MODEL_PATH = "nvidia/GR00T-N1.5-3B"

# Root of the gr00t source tree: the directory one level above the installed
# `gr00t` package directory.
REPO_PATH = os.path.dirname(
    os.path.dirname(gr00t.__file__)
)
# Demo dataset shipped with the repo.
# NOTE(review): this is the GR1 PickNPlace demo; it is not referenced by the
# DROID cells in this notebook.
DATASET_PATH = os.path.join(REPO_PATH, "demo_data/robot_sim.PickNPlace")

# Embodiment tag handed to Gr00tPolicy.
EMBODIMENT_TAG = "oxe_droid"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from gr00t.experiment.data_config import DATA_CONFIG_MAP

# Fetch the DROID data configuration and derive from it the modality config
# and the modality transforms the policy is constructed with.
# (Optionally, the number of denoising steps used by the policy can also be set.)
data_config = DATA_CONFIG_MAP["oxe_droid"]
modality_config = data_config.modality_config()
modality_transform = data_config.transform()

# Load the GR00T checkpoint for the configured embodiment on the chosen device.
policy = Gr00tPolicy(
    model_path=MODEL_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)

# Dump the network architecture of the loaded policy.
print(policy.model)

In [None]:
import numpy as np

# Inspect the modality configuration the loaded policy actually uses.
modality_config = policy.modality_config
print(modality_config.keys())

for name, cfg in modality_config.items():
    # numpy arrays are summarised by their shape; anything else is printed as-is.
    shown = cfg.shape if isinstance(cfg, np.ndarray) else cfg
    print(name, shown)


In [None]:
from PIL import Image
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
    """Letterbox a batch of images to ``(height, width)`` without distortion.

    Replicates ``tf.image.resize_with_pad`` using PIL. Each image is scaled to fit
    inside the target size while keeping its aspect ratio, then zero-padded.

    Args:
        images: Images in ``[..., height, width, channel]`` layout. Any number of
            leading batch dimensions is allowed. Each image goes through
            ``Image.fromarray``, so presumably uint8 RGB data is expected.
        height: Target height in pixels.
        width: Target width in pixels.
        method: PIL resampling filter; bilinear by default.

    Returns:
        Array shaped ``[..., height, width, channel]``. If the input already has
        the target spatial size, the same array object is returned untouched.
    """
    in_shape = images.shape
    if in_shape[-3:-1] == (height, width):
        return images

    batch_dims = in_shape[:-3]
    # Flatten every leading dimension into one batch axis so we can loop per image.
    flat = images.reshape(-1, *in_shape[-3:])
    padded = [
        _resize_with_pad_pil(Image.fromarray(frame), height, width, method=method)
        for frame in flat
    ]
    out = np.stack(padded)
    # Restore the caller's leading dimensions around the new spatial size.
    return out.reshape(*batch_dims, *out.shape[-3:])


def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
    """Letterbox a single PIL image to ``width x height`` using zero padding.

    The image is scaled by the factor that makes it fit entirely inside the
    target box, then pasted onto a black canvas with equal (floored) margins.
    Note that PIL sizes are ``(width, height)``, the reverse of numpy's
    ``[h, w, c]`` layout.
    """
    src_w, src_h = image.size
    if (src_w, src_h) == (width, height):
        return image  # already the requested size

    # Use the larger of the two axis ratios so that both sides fit.
    scale = max(src_w / width, src_h / height)
    new_w = int(src_w / scale)
    new_h = int(src_h / scale)
    scaled = image.resize((new_w, new_h), resample=method)

    canvas = Image.new(scaled.mode, (width, height), 0)
    left = max(0, (width - new_w) // 2)
    top = max(0, (height - new_h) // 2)
    canvas.paste(scaled, (left, top))
    assert canvas.size == (width, height)
    return canvas

In [None]:
import pyrealsense2 as rs
import numpy as np
import cv2
import matplotlib.pyplot as plt



class RealSenseCamera:
    """Thin wrapper around the colour stream of one Intel RealSense camera.

    The streaming pipeline is started in the constructor and keeps running
    until :meth:`release` is called.
    """

    def __init__(self, serial_number, width=640, height=480, fps=30):
        # Bind the pipeline to one physical device (by serial) and request a
        # BGR8 colour stream at the given resolution and frame rate.
        self.serial = serial_number
        self.pipeline = rs.pipeline()
        self.config = rs.config()
        self.config.enable_device(self.serial)
        self.config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, fps)
        self.pipeline.start(self.config)

    def get_image(self):
        """Block for the next frameset and return its colour image converted to RGB.

        Returns:
            The RGB image as a numpy array, or ``None`` if the frameset carried
            no colour frame.
        """
        color = self.pipeline.wait_for_frames().get_color_frame()
        if color:
            return cv2.cvtColor(np.asanyarray(color.get_data()), cv2.COLOR_BGR2RGB)
        return None

    def release(self):
        """Stop the streaming pipeline."""
        self.pipeline.stop()




# Serial numbers of the three RealSense cameras, by mounting position.
serial_left = "948522071060"    # left exterior camera
serial_wrist = "815412071252"   # wrist camera
serial_right = "838212074411"   # right exterior camera

# Open each camera in turn; every constructor starts its pipeline right away.
cam_left = RealSenseCamera(serial_number=serial_left)
cam_wrist = RealSenseCamera(serial_number=serial_wrist)
cam_right = RealSenseCamera(serial_number=serial_right)


In [None]:

# Capture exactly one frame per camera. Each frame is validated and then
# letterboxed to the 256x256 input size used by the policy.
raw_left = cam_left.get_image()
raw_wrist = cam_wrist.get_image()
raw_right = cam_right.get_image()

# get_image() returns None when a frameset has no colour frame; fail with a
# clear message instead of an AttributeError deep inside resize_with_pad.
for _name, _frame in (("left", raw_left), ("wrist", raw_wrist), ("right", raw_right)):
    if _frame is None:
        raise RuntimeError(f"No colour frame received from the {_name} camera")

img_left = resize_with_pad(raw_left, 256, 256)     # (256, 256, 3)
img_wrist = resize_with_pad(raw_wrist, 256, 256)   # (256, 256, 3)
img_right = resize_with_pad(raw_right, 256, 256)   # (256, 256, 3)

# Place the three views side by side: left | wrist | right.
combined = np.hstack((img_left, img_wrist, img_right))  # shape (256, 768, 3)

# Display the combined image (figure aspect roughly matches the 3:1 strip).
plt.figure(figsize=(12, 4))
plt.imshow(combined)
plt.title("Left + Wrist + Right ")
plt.axis('off')
plt.show()

In [None]:
# Stop all three camera pipelines.
# NOTE: the control-loop cell further down still calls get_image() on
# cam_left / cam_wrist / cam_right. Run this cell only after that loop has
# finished, or re-create the cameras before running the loop.
for _cam in (cam_left, cam_wrist, cam_right):
    _cam.release()

In [None]:
import numpy as np

# Dummy single-timestep observation using the same keys that the control loop
# passes to policy.get_action:
#   - three RGB views with a leading time axis of 1,
#   - end-effector position / rotation and gripper state as (1, D) float32,
#   - the language instruction as a one-element list.
# np.random.randint's upper bound is exclusive, so 256 is used to cover the
# full uint8 range 0..255.
step_data = {
    'video.exterior_image_1': np.random.randint(0, 256, (1, 256, 256, 3), dtype=np.uint8),
    'video.exterior_image_2': np.random.randint(0, 256, (1, 256, 256, 3), dtype=np.uint8),
    'video.wrist_image': np.random.randint(0, 256, (1, 256, 256, 3), dtype=np.uint8),
    'state.eef_position': np.random.randn(1, 3).astype(np.float32),
    'state.eef_rotation': np.random.randn(1, 3).astype(np.float32),
    'state.gripper_position': np.random.randn(1, 1).astype(np.float32),
    'annotation.language.language_instruction': ['pick up the red block and place it on the green platform']
}


In [None]:
# A notebook cell only displays its last bare expression, so print every
# state shape explicitly instead of relying on auto-display.
for _key in ('state.eef_position', 'state.eef_rotation', 'state.gripper_position'):
    print(_key, step_data[_key].shape)

In [None]:
# Run one inference step on the dummy observation and report the shape of
# every predicted action component.
predicted_action = policy.get_action(step_data)
for action_name, action_array in predicted_action.items():
    print(action_name, action_array.shape)

In [None]:
# vla_policy_client
import Pyro5.api
import numpy as np
import time

# Connect to the robot-side controller through the Pyro5 name server.
ns = Pyro5.api.locate_ns()  # Locate the name server
uri = ns.lookup("gr00t_controller")  # Look up the registered object by name
# uri = "PYRO:obj_117a8599f4bc45e8a4ab4eb415348147@localhost:43953"  # <-- change URI
controller = Pyro5.api.Proxy(uri)

# Side-by-side frames (left | wrist | right) collected for the GIF cell below.
# NOTE: each frame is 256x768x3 uint8 (~0.6 MB), so 5000 steps hold ~3 GB.
video_buffer = []

# Language instruction, kept as a one-element list like the dummy step_data.
prompt = ["pick the orange toy"]

# Action chunk sent on the very first step: 16 rows x 7 dims
# (3 position delta + 3 rotation delta + 1 gripper), all zeros.
action = np.zeros((16, 7), dtype=np.float32)
action_list = action.tolist()  # convert to plain lists for sending over Pyro5


def _capture(cam, name):
    """Read one RGB frame from *cam*, letterbox it to 256x256 and add a leading time axis."""
    frame = cam.get_image()
    if frame is None:
        raise RuntimeError(f"No colour frame received from the {name} camera")
    return resize_with_pad(frame, 256, 256)[None, ...]  # (1, 256, 256, 3)


def _as_state(value):
    """Convert a state value received from the controller into a float32 (1, D) array."""
    # reshape(1, -1) accepts flat lists, nested [[...]] lists and scalars alike,
    # so the policy always sees the same (1, D) layout as the dummy step_data.
    return np.asarray(value, dtype=np.float32).reshape(1, -1)


for step in range(5000):
    print(f"\n=== Step {step} ===")

    # Send the previous action chunk; the controller replies with an observation dict.
    obs = controller.step({"action": action_list, "step": step})

    img_left = _capture(cam_left, "left")
    img_wrist = _capture(cam_wrist, "wrist")
    img_right = _capture(cam_right, "right")

    # Record the frames in the same left | wrist | right order as the preview cell.
    video_buffer.append(np.hstack([img_left[0], img_wrist[0], img_right[0]]))

    step_data = {
        'video.exterior_image_1': img_left,   # exterior 1 = left camera
        'video.exterior_image_2': img_right,  # exterior 2 = right camera
        'video.wrist_image': img_wrist,
        # NOTE(review): the keys/contents of `obs` are defined by the controller server - confirm.
        'state.eef_position': _as_state(obs['robot_pos']),          # (1, 3)
        'state.eef_rotation': _as_state(obs['robot_rot']),          # (1, 3)
        'state.gripper_position': _as_state(obs['gripper_state']),  # (1, 1)
        'annotation.language.language_instruction': prompt,
    }

    predicted_action = policy.get_action(step_data)
    pos = predicted_action['action.eef_position_delta']   # (16, 3)
    rot = predicted_action['action.eef_rotation_delta']   # (16, 3)
    grip = predicted_action['action.gripper_position']    # (16,) or (16, 1)
    if grip.ndim == 1:
        grip = grip[:, np.newaxis]  # add a feature axis so it concatenates with pos/rot

    action_concat = np.concatenate([pos, rot, grip], axis=-1)  # (16, 7)
    print("Concatenated action:", action_concat.shape)
    action_list = action_concat.tolist()  # sent to the controller on the next step
    


In [None]:
import imageio
import cv2


gif_path = "gr00t_deploy_2.gif"

# Guard against an empty buffer, e.g. when the control loop was interrupted
# before recording its first frame.
if video_buffer:
    # NOTE(review): the unit of `duration` depends on the imageio version
    # (seconds in older releases, milliseconds in newer Pillow-backed GIF
    # writing). Confirm that 0.5 gives the intended frame delay.
    imageio.mimsave(gif_path, video_buffer, duration=0.5)
    print(f"GIF saved:{gif_path}")
else:
    print("video_buffer is empty; no GIF written.")

In [None]:
from IPython.display import Image as IPyImage
# Display the GIF written by the previous cell; reuse gif_path so the two cells cannot drift apart.
IPyImage(filename=gif_path)
