In [1]:
from gymnasium import spaces
import gymnasium as gym
import mani_skill
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
import os
from datasets import load_dataset
import matplotlib.pyplot as plt
import os
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from datasets import load_dataset
from lerobot.policies.pi0.modeling_pi0 import PI0Policy

from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

from moviepy import ImageSequenceClip
from IPython.display import HTML
from base64 import b64encode
import tempfile
import json
import numpy as np
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

In [3]:
seed = 0
env = gym.make(
    "PickCube-v1",
    control_mode="pd_ee_delta_pose",
    num_envs=1,
    obs_mode="rgb",
    render_mode="rgb_array",
    sim_backend="gpu",
    sensor_configs={"width": 224, "height": 224},
)
env = FlattenRGBDObservationWrapper(env, rgb=True, depth=False, state=True)
if isinstance(env.action_space, gym.spaces.Dict):
    env = FlattenActionSpaceWrapper(env)
env = ManiSkillVectorEnv(env, 1, ignore_terminations=True, record_metrics=True)
env.action_space.seed(seed)



[0]

In [4]:
POLICY_PATH = "dancher00/pi0-panda-pickcube-11k-steps"
DATASET_PATH = "/home/user10_2/maniskill-panda-pickcube"

In [5]:
dataset = LeRobotDataset(DATASET_PATH, video_backend="pyav")
policy = PI0Policy.from_pretrained(POLICY_PATH).to("cuda:0")

The dataset you requested (/home/user10_2/maniskill-panda-pickcube) is in 2.0 format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=/home/user10_2/maniskill-panda-pickcube
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).



In [6]:
def prepare_observation_for_policy(obs: dict, device: str) -> dict:
    
    observation = {
        "observation.state": obs['state'][:,:7].to(device),
    }
    
    if "task" in obs:
        observation["task"] = obs["task"]
    else:
        batch_size = len(obs["state"])
        observation["task"] = ["Pick the cube to the target position."] * batch_size
    
    observation["observation.images.main"] = obs["rgb"].permute(0, 3, 1, 2).to(device)
        
    return observation

In [None]:
def inject_normalization_stats(policy, dataset):
    stats = dataset.meta.stats
    print(stats)
    pol_state_dict = policy.state_dict()

    print("Available stats keys:", list(stats.keys()))

    keys_to_update = {
        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
        "normalize_targets.buffer_action.mean": ("action", "mean"),
        "normalize_targets.buffer_action.std": ("action", "std"),
        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
        "unnormalize_outputs.buffer_action.std": ("action", "std"),
    }

    updated_count = 0
    for pol_key, (stat_key, stat_type) in keys_to_update.items():
        if pol_key in pol_state_dict and stat_key in stats:
            pol_state_dict[pol_key] = torch.from_numpy(stats[stat_key][stat_type])
            updated_count += 1
        else:
            print(f"Could not find {pol_key} or {stat_key}")

    policy.load_state_dict(pol_state_dict)
    print("Normalization stats injected into the policy.")

In [7]:
def get_action_chunk(policy, batch, device, actions_per_chunk=10):
    with torch.no_grad():
        batch_processed = prepare_observation_for_policy(batch, device)
        batch_normalized = policy.normalize_inputs(batch_processed)
        
        images, img_masks = policy.prepare_images(batch_normalized)
        state = policy.prepare_state(batch_normalized)
        lang_tokens, lang_masks = policy.prepare_language(batch_normalized)
        
        
        actions = policy.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state
        )
        
        original_action_dim = policy.config.action_feature.shape[0]
        actions = actions[:, :actions_per_chunk, :original_action_dim]
        
        return actions

In [8]:
def tensors_to_video_jupyter(tensors: list[torch.Tensor], fps: int = 30):
    frames = [tensor.squeeze(0).cpu().numpy().astype('uint8') for tensor in tensors]

    clip = ImageSequenceClip(frames, fps=fps)

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        temp_filename = tmpfile.name
        clip.write_videofile(temp_filename, codec="libx264", audio=False)

    with open(temp_filename, "rb") as f:
        video_data = f.read()
    os.remove(temp_filename)

    encoded = b64encode(video_data).decode('utf-8')
    video_html = f'''
    <video width="640" height="480" controls>
        <source src="data:video/mp4;base64,{encoded}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    '''
    return HTML(video_html)

In [9]:
inject_normalization_stats(policy, dataset)
policy.eval()

Available stats keys: ['action', 'observation.state', 'observation.images.main', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']
Normalization stats injected into the policy.


PI0Policy(
  (normalize_inputs): Normalize(
    (buffer_observation_state): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (model): PI0FlowMatching(
    (paligemma_with_expert): PaliGemmaWithExpertModel(
      (paligemma): PaliGemmaForConditionalGeneration(
        (model): PaliGemmaModel(
          (vision_tower): SiglipVisionModel(
            (vision_model): Sig

In [12]:
obs, _ = env.reset(seed = 0)
obs_list = []
ACTIONS_PER_CHUNK = 50
actions = get_action_chunk(policy, obs, "cuda:0", ACTIONS_PER_CHUNK)
for i in range(50):
  obs_list.append(env.render().detach().cpu().clone())
  obs, reward, terminated, truncated, info = env.step(actions[:,i,:])
  print(actions[:,i,:])
  if truncated.any() or terminated.any():
      break
tensors_to_video_jupyter(obs_list, fps = 30)

tensor([[-0.1689, -0.1388,  0.2258,  0.3103,  0.0746, -0.0060, -0.3265]],
       device='cuda:0')
tensor([[-0.1481, -0.1335,  0.2111,  0.3284,  0.0855,  0.0622, -0.3257]],
       device='cuda:0')
tensor([[-0.1238, -0.1059,  0.2256,  0.3355,  0.0595,  0.0486, -0.3382]],
       device='cuda:0')
tensor([[-0.1584, -0.1330,  0.2786,  0.3259,  0.1063,  0.0655, -0.3491]],
       device='cuda:0')
tensor([[-0.1327, -0.1313,  0.2436,  0.3516,  0.1241,  0.0489, -0.3334]],
       device='cuda:0')
tensor([[-0.1286, -0.1473,  0.2560,  0.3614,  0.1016,  0.0646, -0.3013]],
       device='cuda:0')
tensor([[-0.1590, -0.1396,  0.2659,  0.3544,  0.1023,  0.0946, -0.3171]],
       device='cuda:0')
tensor([[-0.1393, -0.1497,  0.2839,  0.3490,  0.0967,  0.0902, -0.3304]],
       device='cuda:0')
tensor([[-0.1214, -0.1486,  0.2834,  0.3454,  0.1257,  0.0748, -0.2975]],
       device='cuda:0')
tensor([[-0.1237, -0.1504,  0.3023,  0.3592,  0.1045,  0.0612, -0.3459]],
       device='cuda:0')
tensor([[-0.1453, -0

                                                             

MoviePy - Done !
MoviePy - video ready /tmp/tmprii2rqqa.mp4




In [None]:
# obs, _ = env.reset(seed = 0)
# obs_list = []
# ACTIONS_PER_CHUNK = 1
# for i in range(50):
#   obs_list.append(env.render().clone())
#   action = get_action_chunk(policy, obs, "cuda:1", ACTIONS_PER_CHUNK)
#   obs, reward, terminated, truncated, info = env.step(action.squeeze(0))
#   if truncated.any() or terminated.any():
#       break

# tensors_to_video_jupyter(obs_list, fps = 30)

In [12]:
processor = AutoProcessor.from_pretrained("MINT-SJTU/RoboFAC-7B", trust_remote_code=True, torch_dtype=torch.float16)
model = AutoModelForVision2Seq.from_pretrained("MINT-SJTU/RoboFAC-7B", trust_remote_code=True, torch_dtype=torch.float16).to("cuda:1")

processor.patch_size = model.config.vision_config.patch_size
processor.num_additional_image_tokens = getattr(model.config.vision_config, "num_additional_image_tokens", 0)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.28it/s]


In [13]:
prompt_high_level = 'From the video of the robotic arm failing during the task, provide high-level corrective commands to guide it to recover and finish the task.'
prompt_low_level = 'This is a video of a robotic arm performing a task, an error occurred during execution. Please provide low-level corrective commands to help the robot recover and complete the task successfully.'

In [14]:
def get_correction(prompt, processor, obs_list, device):
    obs_list = [obs.to(device) for obs in obs_list]
    messages = [
            {
            "role": "user",
            "content": [
                *[{"type": "image", "image": img} for img in obs_list],
                {"type": "text", "text": prompt}
            ],
            }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=obs_list,
        padding=True,
        return_tensors="pt",
    ).to(device)
    output = processor.batch_decode(model.generate(**inputs, max_new_tokens=2048), skip_special_tokens=True)[0]
    return output.split('\nassistant\n', 1)[-1]

In [None]:
obs, _ = env.reset(seed = 0)
obs_list = []
ACTIONS_PER_CHUNK = 20
max_steps = 50
actions = get_action_chunk(policy, obs, "cuda:0", ACTIONS_PER_CHUNK)
for step in range(max_steps):
  if step == ACTIONS_PER_CHUNK:
    correction = get_correction(prompt_low_level, processor, obs_list, "cuda:1")
    obs['task'] = [correction]
    actions = get_action_chunk(policy, obs, "cuda:0", max_steps - ACTIONS_PER_CHUNK)
  obs_list.append(env.render().detach().cpu().clone())
  obs, reward, terminated, truncated, info = env.step(actions[:,step - ACTIONS_PER_CHUNK,:])
  if truncated.any() or terminated.any():
      break
tensors_to_video_jupyter(obs_list, fps = 30)

MoviePy - Building video /tmp/tmpapsiq1mw.mp4.
MoviePy - Writing video /tmp/tmpapsiq1mw.mp4



                                                             

MoviePy - Done !
MoviePy - video ready /tmp/tmpapsiq1mw.mp4




In [None]:
save_dir = '/home/user10_2/reason_pi0'
os.makedirs(save_dir, exist_ok=True)
num_episodes = 2
successes = []
count_ep = 0
for ep in range(num_episodes):
    rgbList, jointsList,actList, rewList, succList, doneList = [], [], [], [], [], []
    obs, _ = env.reset(seed = ep)
    ACTIONS_PER_CHUNK = 50
    actions = get_action_chunk(policy, obs, "cuda:0", ACTIONS_PER_CHUNK)
    for i in range(50):
        rgbList.append(obs['rgb'].cpu().numpy())
        jointsList.append(obs['state'].cpu().numpy())
        obs, reward, terminated, truncated, info = env.step(actions[:,i,:])
        rewList.append(reward.cpu().numpy())
        succList.append(info['success'].cpu().numpy().astype(int))
        actList.append(actions[:,i,:].cpu().numpy())
        done = torch.logical_or(terminated, truncated)
        doneList.append(done.cpu().numpy().astype(int))
    successes.append(info.get("is_success", False))
    if successes[-1]:
        DATA = {'rgb': np.vstack(rgbList),
                'joints': np.vstack(jointsList),
                'action': np.vstack(actList),
                'reward': np.squeeze(np.vstack(rewList)),
                'success': np.squeeze(np.array(succList)),
                'done': np.squeeze(np.array(doneList))}


        file_path = f'{save_dir}/train_data_{count_ep}.npz'
        np.savez(file_path, **DATA)
        count_ep+=1

In [None]:
for i, success in enumerate(successes):
    if success == True:
        continue
    obs, _ = env.reset(seed = i)
    rgbList, jointsList,actList, rewList, succList, doneList = [], [], [], [], [], []
    ACTIONS_PER_CHUNK = 20
    max_steps = 50
    actions = get_action_chunk(policy, obs, "cuda:0", ACTIONS_PER_CHUNK)
    for step in range(max_steps):
        rgbList.append(obs['rgb'].cpu().numpy())
        jointsList.append(obs['state'].cpu().numpy())
        if step == ACTIONS_PER_CHUNK:
            correction = get_correction(prompt_high_level, processor, obs_list, "cuda:1")
            obs['task'] = [correction]
            actions = get_action_chunk(policy, obs, "cuda:0", max_steps - ACTIONS_PER_CHUNK)
        obs, reward, terminated, truncated, info = env.step(actions[:,step - ACTIONS_PER_CHUNK,:])
        rewList.append(reward.cpu().numpy())
        succList.append(info['success'].cpu().numpy().astype(int))
        actList.append(actions[:,i,:].cpu().numpy())
        done = torch.logical_or(terminated, truncated)
        doneList.append(done.cpu().numpy().astype(int))
    if info.get("is_success", False):
        DATA = {'rgb': np.vstack(rgbList),
                'joints': np.vstack(jointsList),
                'action': np.vstack(actList),
                'reward': np.squeeze(np.vstack(rewList)),
                'success': np.squeeze(np.array(succList)),
                'done': np.squeeze(np.array(doneList))}

        file_path = f'{save_dir}/train_data_{count_ep}.npz'
        np.savez(file_path, **DATA)
        count_ep+=1