In [1]:
!pip install ninja "packaging>=24.2,<26.0"
!pip install peft
!pip install dm-tree==0.1.9
!pip install -U transformers
!pip install flash-attn==2.7.3 --no-build-isolation

Collecting peft
  Downloading peft-0.18.1-py3-none-any.whl.metadata (14 kB)
Downloading peft-0.18.1-py3-none-any.whl (556 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.0/557.0 kB[0m [31m66.5 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: peft
Successfully installed peft-0.18.1
Collecting dm-tree==0.1.9
  Downloading dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting wrapt>=1.11.2 (from dm-tree==0.1.9)
  Downloading wrapt-2.0.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (9.0 kB)
Downloading dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (153 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.0/153.0 kB[0m [31m26.2 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading wrapt-2.0.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl (121 kB)
[2K   [9

In [2]:
import random
import numpy as np
import os
import torch
import json
from PIL import Image
from src.env.env import RILAB_OMY_ENV
from torchvision import transforms

from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
    batch_to_transition,
    policy_action_to_transition,
    transition_to_batch,
    transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


import glfw

  from .autonotebook import tqdm as notebook_tqdm


## Load Model

In [None]:
'''
Load the pretrained GR00T policy checkpoint and move it to the inference device
'''
# Local checkpoint directory; the commented value is the alternative hub repo id.
repo_id_or_path = 'ckpt/tutorial_v2_groot' #'Jeongeun/tutorial_v2_groot'
device = 'cuda'

# Dataset metadata is currently not loaded (kept for reference only):
# dataset_metadata = LeRobotDatasetMetadata("Jeongeun/tutorial_v2", root=ROOT)
policy = GrootPolicy.from_pretrained(repo_id_or_path)
# Bind the return values to `_` so the cell does not echo the full module repr
# of the model into the notebook output.
_ = policy.to(device)
# Explicitly select inference mode for evaluation; a no-op if the loader
# already put the module in eval mode.
_ = policy.eval()

[GROOT] Flash Attention version: 2.7.3
Loading pretrained dual brain from nvidia/GR00T-N1.5-3B
Tune backbone vision tower: False
Tune backbone LLM: False
Tune action head projector: True
Tune action head DiT: False


Fetching 13 files:   8%|▊         | 1/13 [00:01<00:23,  1.93s/it]

In [None]:
# Manual construction of the pre/post-processing pipelines from the checkpoint.
# The helper call below is disabled and kept for reference only:
# preprocessor, postprocessor = make_groot_pre_post_processors(
#         config=policy.config,
#         dataset_stats= dataset_metadata.stats
#     )

# `kwargs` starts out empty, so `kwargs.get("dataset_stats")` evaluates to None
# for both overrides below.
# NOTE(review): presumably the statistics stored with the checkpoint are used
# when the override is None -- confirm against the processor implementation.
kwargs = {}

# Overrides for the input-packing step of the preprocessor.
preprocessor_overrides = {
    "groot_pack_inputs_v3": {
        "stats": kwargs.get("dataset_stats"),
        "normalize_min_max": True,
    },
}

# Overrides for the action-unpacking step of the postprocessor: the action is
# sliced to the environment's action dimension, read from the policy config.
env_action_dim = policy.config.output_features["action"].shape[0]
postprocessor_overrides = {
    "groot_action_unpack_unnormalize_v1": {
        "stats": kwargs.get("dataset_stats"),
        "normalize_min_max": True,
        "env_action_dim": env_action_dim,
    },
}

kwargs.update(
    preprocessor_overrides=preprocessor_overrides,
    postprocessor_overrides=postprocessor_overrides,
)

# Raw observation batch -> transition -> processed batch.
preprocessor = PolicyProcessorPipeline.from_pretrained(
    pretrained_model_name_or_path=repo_id_or_path,
    config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
    overrides=preprocessor_overrides,
    to_transition=batch_to_transition,
    to_output=transition_to_batch,
)

# Policy action -> transition -> processed policy action.
postprocessor = PolicyProcessorPipeline.from_pretrained(
    pretrained_model_name_or_path=repo_id_or_path,
    config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
    overrides=postprocessor_overrides,
    to_transition=policy_action_to_transition,
    to_output=transition_to_policy_action,
)

In [None]:
# Warm-up pass: push one all-zero sample through the preprocessor and the policy
# so that both are initialised before the first real episode.
warmup_state_shape = (1, 8)             # batch of 1, 8-dimensional state vector
warmup_image_shape = (1, 3, 256, 256)   # batch of 1, 3x256x256 image
warmup_observation = {
    'observation.state': np.zeros(warmup_state_shape, dtype=np.float32),
    'observation.image': np.zeros(warmup_image_shape, dtype=np.float32),
    'observation.wrist_image': np.zeros(warmup_image_shape, dtype=np.float32),
    'task': ['move the object to the target position'],
}
batch = preprocessor(warmup_observation)  # to initialize the processors
_ = policy.select_action(batch)  # to initialize the model

## Load Environment

In [None]:
'''
Evaluation configuration constants
'''
# Number of episodes rolled out by the evaluation loop.
TEST_EPISODES: int = 20
# Upper bound on the environment tick counter within a single episode.
MAX_EPISODE_STEPS: int = 10000

In [None]:
# Read the environment configuration from JSON and build the simulation
# environment from it.
# NOTE(review): this evaluation notebook reads './configs/train.json' --
# confirm that reusing the training config here is intended.
config_file_path = './configs/train.json'
with open(config_file_path) as config_file:
    env_conf = json.load(config_file)

omy_env = RILAB_OMY_ENV(
    cfg=env_conf,
    seed=0,
    action_type='joint',
    obs_type='joint_pos',
    vis_mode='teleop',
)

In [None]:
def get_default_transform():
    """
    Build the image transform applied to every camera frame.

    The pipeline consists of a single ``ToTensor`` step, which turns a PIL
    image with values in [0, 255] into a FloatTensor in [0.0, 1.0] laid out
    as C x H x W.
    """
    to_tensor = transforms.ToTensor()
    return transforms.Compose([to_tensor])


# Shared transform instance used by the evaluation loop.
IMG_TRANSFORM = get_default_transform()

## Evaluation

In [None]:
'''
Run one evaluation episode
'''
def _prepare_image(raw_image, size=(256, 256)):
    """Turn a raw camera array into a resized image tensor via IMG_TRANSFORM."""
    return IMG_TRANSFORM(Image.fromarray(raw_image).resize(size))


def run_one_episode():
    """
    Roll out the policy for a single episode in ``omy_env``.

    The environment and the policy are reset, then the simulation is stepped
    until the task succeeds, the 'z' key is pressed, the viewer is closed, or
    the tick counter reaches ``MAX_EPISODE_STEPS``. Whenever
    ``loop_every(HZ=20)`` fires, the current state and both camera images are
    preprocessed, passed to the policy, and the postprocessed action is applied.

    Returns:
        The last value returned by ``omy_env.check_success()``, or ``False``
        if the success check was never reached (e.g. the viewer was already
        closed when the episode started).
    """
    # Defined up front: previously `return success` raised UnboundLocalError
    # whenever the loop ended before the first success check.
    success = False

    omy_env.reset()
    policy.reset()
    observation = omy_env.get_observation()
    omy_env.env.tick = 0

    while omy_env.env.is_viewer_alive() and omy_env.env.tick < MAX_EPISODE_STEPS:
        omy_env.step_env()
        # Only query the policy when the 20 Hz gate fires; otherwise keep stepping.
        if not omy_env.env.loop_every(HZ=20):
            continue

        success = omy_env.check_success()
        if success:
            break
        if omy_env.env.is_key_pressed_once(glfw.KEY_Z):
            break  # for debugging: press 'z' to end the episode

        agent_image, wrist_image = omy_env.grab_image(return_side=False)
        # NOTE(review): unlike the warm-up batch, no 'task' instruction is put
        # into the frame -- confirm the preprocessor supplies the intended one.
        frame = {
            "observation.state": observation,
            "observation.image": _prepare_image(agent_image),
            "observation.wrist_image": _prepare_image(wrist_image),
        }
        # Add a leading batch dimension and move everything to the policy device.
        # as_tensor avoids re-wrapping the image tensors produced by IMG_TRANSFORM.
        frame = {
            key: torch.as_tensor(value, dtype=torch.float32).unsqueeze(0).to(device)
            for key, value in frame.items()
        }

        frame = preprocessor(frame)            # pre-process the frame
        action = policy.select_action(frame)   # select action
        action = postprocessor(action)         # post-process the action
        action = action.squeeze(0).cpu().numpy()

        # The environment step returns the observation used for the next query.
        observation = omy_env.step(action, gripper_mode='continuous')
        omy_env.render()
    return success

In [None]:
'''
Run evaluation over multiple episodes
'''
results = []
try:
    for episode in range(TEST_EPISODES):
        success = run_one_episode()
        results.append(success)
        print(f"Episode {episode+1}/{TEST_EPISODES} - Success: {success}")
finally:
    # Close the viewer even when an episode raises, so the window is not leaked.
    omy_env.env.close_viewer()

# log average success rate
# Guarded so that TEST_EPISODES == 0 reports NaN instead of dividing by zero.
avg_success = sum(results) / len(results) if results else float("nan")
print(f"Average Success Rate over {TEST_EPISODES} episodes: {avg_success*100:.2f}%")
