In [None]:
import lerobot
from lerobot.policies.factory import make_policy
from lerobot.configs.train import TrainPipelineConfig, PreTrainedConfig
from lerobot.policies.pi05 import (  # noqa: E402
    PI05Config,
    PI05Policy,
    make_pi05_pre_post_processors,  # noqa: E402
    
)
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from pprint import pprint
import torch

In [None]:
import gc

# Drop any policy object left over from a previous run of the notebook so the
# garbage collector can reclaim it before a new one is created.
try:
    del policy
except NameError:
    # Nothing to clean up: no policy has been created in this kernel yet.
    pass
gc.collect()

# Mock data

In [None]:
from lerobot.configs.types import FeatureType, PolicyFeature

# Policy configuration used for the mock-data smoke test below.
config = PI05Config(
    max_action_dim=7,
    max_state_dim=14,
    dtype="bfloat16",
    device="cuda",
)

# Inputs: a 14-element state vector and one 3x224x224 camera image.
config.input_features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
    "observation.images.base_1_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}

# Output: a 7-element action vector.
config.output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}

In [None]:
# Build the PI0.5 policy from the `config` object defined in the previous cell.
policy = PI05Policy(config)


In [None]:
import torch

batch_size = 1
device = "cuda"

# Placeholder normalization statistics: zeros for mean/min/q01 and ones for
# std/max/q99. The image entry carries no min/max statistics.
dataset_stats = {
    feature: {
        stat: (torch.ones(shape) if stat in ("std", "max", "q99") else torch.zeros(shape))
        for stat in stat_names
    }
    for feature, shape, stat_names in (
        ("observation.state", (14,), ("mean", "std", "min", "max", "q01", "q99")),
        ("action", (7,), ("mean", "std", "min", "max", "q01", "q99")),
        ("observation.images.base_1_rgb", (3, 224, 224), ("mean", "std", "q01", "q99")),
    )
}
preprocessor, postprocessor = make_pi05_pre_post_processors(
    config=config,
    dataset_stats=dataset_stats,
)

# Random mock batch. Entries are created in the same order as before (state,
# action, image) so the random number stream is consumed identically.
batch = {
    "observation.state": torch.randn(
        batch_size, 14, dtype=torch.float32, device=device
    ),
    "action": torch.randn(
        batch_size, config.chunk_size, 7, dtype=torch.float32, device=device
    ),
    # torch.rand keeps pixel values in the [0, 1) range.
    "observation.images.base_1_rgb": torch.rand(
        batch_size, 3, 224, 224, dtype=torch.float32, device=device
    ),
    "task": ["Pick up the object"] * batch_size,
}

In [None]:
# Run the mock batch through the preprocessor and display the result.
# NOTE(review): `input` shadows the Python builtin of the same name; the next
# cell reads this variable, so any rename has to be applied to both cells.
input = preprocessor(batch)
input

In [None]:
# Ask the policy for an action on the preprocessed mock batch, pass it through
# the postprocessor, and display the result.
action = policy.select_action(input)
output = postprocessor(action)
output

# Xarm

In [None]:
import mujoco
from pathlib import Path
import gym_lite6
import gymnasium as gym
import gym_lite6.env, gym_lite6.scripted_policy, gym_lite6.pickup_task
import mediapy as media
import numpy as np


from importlib import reload

# gym_lite6.utils is reloaded below, so import it explicitly instead of relying
# on another gym_lite6 module having imported it as a side effect.
import gym_lite6.utils

# Reload the gym_lite6 modules so on-disk edits take effect without restarting
# the kernel. utils goes first: importlib.reload() does not rebind names that
# other modules already hold, so modules reloaded after it re-import the fresh
# definitions (presumably utils is the shared dependency - TODO confirm).
for module in (
    gym_lite6.utils,
    gym_lite6.env,
    gym_lite6.scripted_policy,
    gym_lite6.pickup_task,
):
    reload(module)

# Alternative task:
# task = gym_lite6.pickup_task.GraspTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')
task = gym_lite6.pickup_task.GraspAndLiftTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')

env = gym.make(
    "UfactoryCubePickup-v0",
    task=task,
    obs_type="pixels_state",
    max_episode_steps=500,
    visualization_width=320,
    visualization_height=240,
    render_fps=30,
    joint_noise_magnitude=0.1,
)

# Reset once and show the side camera view as a visual sanity check.
observation, info = env.reset()
media.show_image(env.unwrapped.render(camera="side_cam"))


In [None]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

import os

# Location of the recorded Lite6 dataset. The default is the original
# machine-specific path; set LEROBOT_DATASET_PATH to run the notebook elsewhere.
dataset_path = os.environ.get(
    "LEROBOT_DATASET_PATH",
    "/media/ssd/eugene/robotic_manipulation/lerobot_tests/datasets/lite6_record_scripted_250622",
)
dataset_meta = LeRobotDatasetMetadata(dataset_path)
# Display the dataset statistics; they are passed to the pre/post processors later.
dataset_meta.stats

In [None]:
# Display the feature description stored in the dataset metadata.
dataset_meta.features

In [None]:
from lerobot.configs.types import FeatureType, PolicyFeature

# Policy configuration for the Lite6 simulation environment.
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="bfloat16", device="cuda")

# Inputs: a 7-element state vector plus the side and gripper cameras.
# Image shapes are channel-first (C, H, W) so they agree with the tensors built
# by numpy_to_torch_obs (which permutes HWC -> CHW) and with the mock config
# above; the environment renders 320x240 frames.
config.input_features = {
    "observation.state": PolicyFeature(
        type=FeatureType.STATE,
        shape=(7,),
    ),
    "observation.images.side": PolicyFeature(
        type=FeatureType.VISUAL,
        shape=(3, 240, 320),
    ),
    "observation.images.gripper": PolicyFeature(
        type=FeatureType.VISUAL,
        shape=(3, 240, 320),
    ),
}

# Output: a 7-element action vector.
config.output_features = {
    "action": PolicyFeature(
        type=FeatureType.ACTION,
        shape=(7,),
    ),
}

# Build the pre/post processors using the statistics from the dataset metadata.
preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_meta.stats)


In [None]:
# Reset the environment and keep the raw observation it returns; later cells
# convert it to torch tensors with numpy_to_torch_obs.
numpy_observation, info = env.reset()


In [None]:
def numpy_to_torch_obs(numpy_observation, device=None):
    """Convert a raw environment observation into batched torch tensors.

    Args:
        numpy_observation: Nested mapping with ``["state"]["qpos"]``,
            ``["state"]["gripper"]``, ``["pixels"]["side"]`` and
            ``["pixels"]["gripper"]``. The pixel arrays are indexed as
            (height, width, channel).
        device: Target device for the tensors. Defaults to ``config.device``
            (the notebook-level policy config) when omitted.

    Returns:
        A dict with ``observation.state`` (float32, joint positions followed by
        the gripper value, with a leading batch axis of 1) and
        ``observation.images.side`` / ``observation.images.gripper``
        (channel-first, leading batch axis of 1, pixel values divided by 255).
    """
    if device is None:
        # Fall back to the device of the notebook's current policy config.
        device = config.device

    def _image_to_tensor(pixels):
        # HWC -> CHW, add the batch axis, move to the device, then rescale by 1/255.
        return torch.from_numpy(pixels).permute((2, 0, 1)).unsqueeze(0).to(device) / 255

    state = numpy_observation["state"]
    pixels = numpy_observation["pixels"]
    state_vector = np.hstack((state["qpos"], state["gripper"])).astype(np.float32)

    return {
        "observation.state": torch.from_numpy(state_vector).unsqueeze(0).to(device),
        "observation.images.side": _image_to_tensor(pixels["side"]),
        "observation.images.gripper": _image_to_tensor(pixels["gripper"]),
    }
# Convert the latest raw observation to tensors, attach the language
# instruction, and display the result.
observation = {
    **numpy_to_torch_obs(numpy_observation),
    "task": ["Pick up the red cube"],
}
observation


In [None]:
# Re-create the policy from the current `config` (the Lite6 configuration when
# the notebook is run top to bottom).
policy = PI05Policy(config)


In [None]:
# Single-step check of the full pipeline on a real environment observation:
# convert -> preprocess -> policy -> postprocess.
observation = numpy_to_torch_obs(numpy_observation)
observation["task"] = ["Pick up the red cube"]
observation = preprocessor(observation)
# Keep the policy output: postprocessing must run on this action, not on a
# stale `action` left over from an earlier cell.
action = policy.select_action(observation)
postprocessor(action)

In [None]:
# Roll out one episode with the policy in the simulation, collecting rewards,
# side-camera frames, and a per-step log in `ep_dict`.
policy.reset()
policy.eval()
numpy_observation, info = env.reset()
rewards = []
frames = [numpy_observation["pixels"]["side"].squeeze()]
done = False
step = 0

ep_dict = {
    key: []
    for key in (
        "action.qpos",
        "action.gripper",
        "observation.state.qpos",
        "observation.state.qvel",
        "observation.state.gripper",
        "observation.images.side",
        "observation.images.gripper",
        "reward",
        "timestamp",
        "frame_index",
    )
}
while not done:
    observation = numpy_to_torch_obs(numpy_observation)
    observation["task"] = ["Pick up the red cube"]
    observation = preprocessor(observation)
    with torch.inference_mode():
        action = policy.select_action(observation)
        # Take the first (only) element of the batch.
        action = postprocessor(action)[0]
    # Split the action vector: the leading `dof` entries are joint targets and
    # the last entry is the gripper command, clipped to [-1, 1] and rounded.
    action = {"qpos": action[:env.unwrapped.dof], "gripper": round(np.clip(action[-1].item(), -1, 1))}
    numpy_observation, reward, terminated, truncated, info = env.step(action)

    rewards.append(reward)
    frames.append(numpy_observation["pixels"]["side"].squeeze())

    ep_dict["action.qpos"].append(action["qpos"])
    ep_dict["action.gripper"].append(action["gripper"])
    ep_dict["observation.state.qpos"].append(numpy_observation["state"]["qpos"])
    ep_dict["observation.state.qvel"].append(numpy_observation["state"]["qvel"])
    ep_dict["observation.state.gripper"].append(numpy_observation["state"]["gripper"])
    ep_dict["observation.images.side"].append(numpy_observation["pixels"]["side"])
    ep_dict["observation.images.gripper"].append(numpy_observation["pixels"]["gripper"])
    ep_dict["reward"].append(reward)
    ep_dict["timestamp"].append(env.unwrapped.data.time)
    ep_dict["frame_index"].append(step)

    # Advance the frame counter; without this every frame_index would be 0.
    step += 1
    done = bool(terminated or truncated)

In [None]:
import mediapy as media

# Play back the recorded side-camera frames at the environment's render_fps
# (30, as passed to gym.make) instead of mediapy's default frame rate.
# NOTE(review): assumes one env.step corresponds to one rendered frame - confirm.
media.show_video(frames, fps=30)

# Finetune

In [None]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

import os

# Location of the recorded Lite6 dataset used for fine-tuning. The default is
# the original machine-specific path; set LEROBOT_DATASET_PATH to override it.
dataset_path = os.environ.get(
    "LEROBOT_DATASET_PATH",
    "/media/ssd/eugene/robotic_manipulation/lerobot_tests/datasets/lite6_record_scripted_250622",
)
dataset_meta = LeRobotDatasetMetadata(dataset_path)
# Display the dataset statistics.
dataset_meta.stats