In [6]:
from gymnasium import spaces
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
import os
from datasets import load_dataset
import os
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.pi0.modeling_pi0 import PI0Policy

from moviepy import ImageSequenceClip
from IPython.display import HTML
from base64 import b64encode
import tempfile
import json
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image
import simpler_env
from simpler_env.utils.env.observation_utils import get_image_from_maniskill3_obs_dict
from gymnasium import spaces
import gymnasium as gym
from mani_skill.envs.tasks.digital_twins.bridge_dataset_eval import *

In [7]:
# Pin the CUDA_VISIBLE_DEVICES environment variable of this process to "0".
# NOTE(review): this only restricts device visibility if it runs before CUDA is
# initialized in the kernel; torch is already imported above, so confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [8]:
def auto_model_from_pretrained(path, **kwargs):
    """Load a PI0 policy checkpoint and move it to the requested device.

    Args:
        path: Checkpoint location forwarded to ``PI0Policy.from_pretrained``.
            It is also placed on ``sys.path`` (once).
        **kwargs: Extra keyword arguments forwarded to
            ``PI0Policy.from_pretrained``. ``map_location`` (default ``"cpu"``)
            is consumed here and used as the argument of ``.to(...)``.

    Returns:
        The object returned by ``PI0Policy.from_pretrained(...).to(map_location)``.
    """
    import sys

    # Notebook cells get re-run; avoid stacking duplicate entries on sys.path.
    if path not in sys.path:
        sys.path.append(path)

    map_location = kwargs.pop("map_location", "cpu")
    return PI0Policy.from_pretrained(path, **kwargs).to(map_location)

# Checkpoint location. Set the PI0_MODEL_PATH environment variable to load a
# different checkpoint without editing the notebook; the previous hard-coded
# path remains the default.
saved_model_path = os.environ.get("PI0_MODEL_PATH", "/home/user10_2/lerobot-pi0-bridge")

# Load the policy and place it on the first CUDA device.
policy = auto_model_from_pretrained(saved_model_path, map_location="cuda:0")


Loading weights from local directory


In [22]:
# Normalization statistics file. Set the PI0_STATS_PATH environment variable to
# use a different file; the previous hard-coded path remains the default.
stats_path = os.environ.get("PI0_STATS_PATH", '/home/user10_2/stats.json')
with open(stats_path, 'r', encoding='utf-8') as f:
    stats = json.load(f)

In [23]:
def prepare_observation_for_policy(
    obs: dict,
    device: str,
    default_task: str = "stack the green block on the yellow block",
) -> dict:
    """Convert a wrapped environment observation into the batch dict fed to the policy.

    Args:
        obs: Observation dict with
            - ``"state"``: tensor, moved to ``device`` unchanged;
            - ``"rgb"``: 4-D image batch; it is permuted with ``(0, 3, 1, 2)``,
              i.e. presumably ``(B, H, W, C)`` -> ``(B, C, H, W)``;
            - ``"task"`` (optional): language instruction(s), passed through as-is.
        device: Torch device the tensors are moved to.
        default_task: Instruction repeated once per batch element when ``obs``
            has no ``"task"`` entry.

    Returns:
        Dict with the keys ``"observation.state"``, ``"task"`` and
        ``"observation.images.image_0"`` (float32 tensor on ``device``).
    """
    rgb = obs["rgb"]

    if "task" in obs:
        task = obs["task"]
    else:
        task = [default_task] * len(rgb)

    image = rgb.permute(0, 3, 1, 2)
    if image.is_floating_point():
        # Float images are assumed to be scaled already; only the dtype is unified.
        image = image.to(torch.float32)
    else:
        # Integer camera frames are 0..255. lerobot's PI0 image preprocessing
        # maps [0, 1] to [-1, 1], so integer inputs are rescaled to [0, 1] here.
        image = image.to(torch.float32) / 255.0

    return {
        "observation.state": obs["state"].to(device),
        "task": task,
        "observation.images.image_0": image.to(device),
    }

In [24]:
def inject_normalization_stats(policy, stats):
    """Overwrite the policy's normalization buffers with dataset statistics.

    The mean/std entries of ``stats["observation.state"]`` and
    ``stats["action"]`` are written into the matching ``normalize_inputs``,
    ``normalize_targets`` and ``unnormalize_outputs`` entries of the policy's
    state dict, which is then loaded back into the policy.

    Args:
        policy: Module exposing ``state_dict()`` / ``load_state_dict()``.
        stats: Mapping ``feature name -> {"mean": ..., "std": ...}`` with
            array-like values.

    Entries that cannot be resolved (buffer missing in the policy, feature
    missing in ``stats``, or statistic missing for the feature) are reported and
    skipped. The final message states how many buffers were actually updated.
    """
    pol_state_dict = policy.state_dict()

    print("Available stats keys:", list(stats.keys()))

    keys_to_update = {
        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
        "normalize_targets.buffer_action.mean": ("action", "mean"),
        "normalize_targets.buffer_action.std": ("action", "std"),
        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
        "unnormalize_outputs.buffer_action.std": ("action", "std"),
    }

    updated_count = 0
    for pol_key, (stat_key, stat_type) in keys_to_update.items():
        if pol_key not in pol_state_dict or stat_key not in stats:
            print(f"Could not find {pol_key} or {stat_key}")
            continue
        if stat_type not in stats[stat_key]:
            # Previously this surfaced as a bare KeyError from the lookup below.
            print(f"Could not find '{stat_type}' statistics for {stat_key}")
            continue

        current = pol_state_dict[pol_key]
        # Create the replacement with the buffer's own dtype/device instead of
        # relying on load_state_dict to convert a float64 CPU tensor.
        pol_state_dict[pol_key] = torch.as_tensor(
            np.asarray(stats[stat_key][stat_type]),
            dtype=current.dtype,
            device=current.device,
        )
        updated_count += 1

    policy.load_state_dict(pol_state_dict)
    print(
        "Normalization stats injected into the policy "
        f"({updated_count}/{len(keys_to_update)} buffers updated)."
    )

In [25]:
def get_action_chunk(policy, batch, device, actions_per_chunk=10):
    """Sample one chunk of actions from the policy for the given observation.

    Args:
        policy: PI0-style policy exposing ``normalize_inputs``,
            ``prepare_images``, ``prepare_state``, ``prepare_language``,
            ``model.sample_actions``, ``config.action_feature`` and
            ``unnormalize_outputs``.
        batch: Raw observation dict accepted by ``prepare_observation_for_policy``.
        device: Torch device the observation tensors are moved to.
        actions_per_chunk: Number of leading time steps to keep from the sample.

    Returns:
        Tensor sliced to ``[:, :actions_per_chunk, :action_dim]`` and passed
        through ``policy.unnormalize_outputs``, where ``action_dim`` is
        ``policy.config.action_feature.shape[0]``.
    """
    with torch.no_grad():
        batch_processed = prepare_observation_for_policy(batch, device)
        batch_normalized = policy.normalize_inputs(batch_processed)

        images, img_masks = policy.prepare_images(batch_normalized)
        state = policy.prepare_state(batch_normalized)
        lang_tokens, lang_masks = policy.prepare_language(batch_normalized)

        actions = policy.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state
        )

        # Keep the requested horizon and drop action dimensions beyond the
        # configured action size.
        original_action_dim = policy.config.action_feature.shape[0]
        actions = actions[:, :actions_per_chunk, :original_action_dim]

        # The raw sample was returned without this step before, although the
        # notebook injects action statistics into `unnormalize_outputs`.
        # Apply them so the environment receives actions in their original scale.
        actions = policy.unnormalize_outputs({"action": actions})["action"]

        return actions

In [26]:
def tensors_to_video_jupyter(tensors: list[torch.Tensor], fps: int = 30):
    """Encode a list of frame tensors as an MP4 and return it as inline HTML.

    Args:
        tensors: Non-empty list of frames. Each tensor is squeezed on dim 0,
            moved to CPU and cast to uint8 before encoding.
        fps: Frame rate of the resulting video.

    Returns:
        An ``IPython.display.HTML`` object embedding the video as a base64
        data URI.

    Raises:
        ValueError: If ``tensors`` is empty.
    """
    if len(tensors) == 0:
        raise ValueError("tensors_to_video_jupyter() needs at least one frame")

    frames = [tensor.squeeze(0).cpu().numpy().astype('uint8') for tensor in tensors]

    clip = ImageSequenceClip(frames, fps=fps)

    # Only reserve a unique path here; the handle is closed before the encoder
    # writes to it, so the file is never open twice.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        temp_filename = tmpfile.name

    try:
        clip.write_videofile(temp_filename, codec="libx264", audio=False)
        with open(temp_filename, "rb") as f:
            video_data = f.read()
    finally:
        # Remove the temporary file even when encoding or reading fails.
        os.remove(temp_filename)

    encoded = b64encode(video_data).decode('utf-8')
    video_html = f'''
    <video width="640" height="480" controls>
        <source src="data:video/mp4;base64,{encoded}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    '''
    return HTML(video_html)

In [27]:
# Write the statistics loaded into `stats` into the policy's normalization buffers.
inject_normalization_stats(policy, stats)
# Switch the policy to evaluation mode. As the last expression of the cell, the
# value returned by eval() is what the notebook displays below.
policy.eval()

Available stats keys: ['observation.images.image_3', 'observation.images.image_2', 'observation.images.image_1', 'observation.images.image_0', 'observation.state', 'action', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']
Normalization stats injected into the policy.


PI0Policy(
  (normalize_inputs): Normalize(
    (buffer_observation_state): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 8 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 8 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (model): PI0FlowMatching(
    (paligemma_with_expert): PaliGemmaWithExpertModel(
      (paligemma): PaliGemmaForConditionalGeneration(
        (model): PaliGemmaModel(
          (vision_tower): SiglipVisionModel(
            (vision_model): Sig

In [28]:
class ObsWrapper(gym.ObservationWrapper):
    """Reduce the environment's nested observation to an 'rgb' and a 'state' entry.

    'rgb' is the 'rgb' entry of the '3rd_view_camera' sensor and 'state' is the
    agent's 'qpos' entry; the advertised observation space is narrowed to match.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

        # Pick the two sub-spaces that survive the wrapping.
        full_space = env.observation_space.spaces
        camera_rgb_space = full_space['sensor_data'].spaces['3rd_view_camera'].spaces['rgb']
        qpos_space = full_space['agent'].spaces['qpos']

        self.observation_space = spaces.Dict({'rgb': camera_rgb_space, 'state': qpos_space})

    def observation(self, obs: dict) -> dict:
        """Return a new dict holding only the camera image and the agent's qpos."""
        camera_rgb = obs['sensor_data']['3rd_view_camera']['rgb']
        qpos = obs['agent']['qpos']
        return {'rgb': camera_rgb, 'state': qpos}

In [49]:
# Create the task environment and wrap it right away so observations come back
# as {'rgb', 'state'} dicts.
env = ObsWrapper(
    gym.make(
        "PutSpoonOnTableClothInScene-v1",
        obs_mode="rgb+segmentation",
        num_envs=1,  # a value above 1 switches to the GPU simulation backend
    )
)
obs, _ = env.reset()

# One language instruction is returned per parallel environment; show the first.
instruction = env.unwrapped.get_language_instruction()
print("instruction:", instruction[0])



instruction: put the spoon on the towel


In [44]:
# Roll out the policy for one episode and render the recorded frames.
obs, _ = env.reset(seed = 0)

# The wrapped observation carries no "task" entry, so without this the policy
# would be prompted with the fallback instruction baked into
# prepare_observation_for_policy instead of this environment's actual task.
instruction = env.unwrapped.get_language_instruction()

obs_list = []
ACTIONS_PER_CHUNK = 4
done = False
actions = None
count_act = 0
i = 0  # total number of environment steps taken
while not done:
    # Query the policy at the start and whenever the current chunk is used up.
    if actions is None or count_act == ACTIONS_PER_CHUNK:
        count_act = 0
        policy_obs = {**obs, "task": list(instruction)}
        actions = get_action_chunk(policy, policy_obs, "cuda:0", ACTIONS_PER_CHUNK)
    # Store a detached CPU copy so later simulator steps cannot alter the frame.
    obs_list.append(obs['rgb'].detach().cpu().clone())
    obs, reward, terminated, truncated, info = env.step(actions[0, count_act])
    done = bool(torch.logical_or(terminated, truncated).any())
    count_act += 1
    i += 1
tensors_to_video_jupyter(obs_list, fps = 30)

MoviePy - Building video /tmp/tmp7g0_wgbf.mp4.
MoviePy - Writing video /tmp/tmp7g0_wgbf.mp4



                                                             

MoviePy - Done !
MoviePy - video ready /tmp/tmp7g0_wgbf.mp4




In [50]:
# Baseline: drive the environment with random actions for up to 50 steps and
# render the recorded frames.
obs, _ = env.reset()
obs_list = []
for i in range(50):
    # Store a detached CPU copy, matching the policy rollout cell, so the saved
    # frame does not alias a tensor the simulator may reuse.
    obs_list.append(obs['rgb'].detach().cpu().clone())
    action = env.action_space.sample()  # replace this with your policy inference
    obs, reward, terminated, truncated, info = env.step(action)
    if truncated.any() or terminated.any():
        break
tensors_to_video_jupyter(obs_list, fps = 30)

MoviePy - Building video /tmp/tmphmishz8w.mp4.
MoviePy - Writing video /tmp/tmphmishz8w.mp4



                                                             

MoviePy - Done !
MoviePy - video ready /tmp/tmphmishz8w.mp4


