In [None]:
import lerobot
from lerobot.policies.factory import make_policy
from lerobot.configs.train import TrainPipelineConfig, PreTrainedConfig
from lerobot.policies.pi05 import (  # noqa: E402
    PI05Config,
    PI05Policy,
    make_pi05_pre_post_processors,  # noqa: E402
)
import lerobot.policies.pi05
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from pprint import pprint
import torch

In [None]:
import gc

# Drop any `policy` left over from an earlier run of this notebook, then force
# a garbage-collection pass so the released objects are reclaimed right away.
globals().pop("policy", None)
gc.collect()

# Mock data

In [None]:
from lerobot.configs.types import FeatureType, PolicyFeature

# Mock PI0.5 configuration: 14-dim state in, 7-dim action out, bfloat16 on CUDA.
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="bfloat16", device="cuda")

# Features the policy reads: one state vector and one RGB camera stream.
config.input_features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
    "observation.images.base_1_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}

# Feature the policy produces: a single action vector.
config.output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}

In [None]:
# Instantiate the PI0.5 policy from the mock `config` built in the previous cell.
policy = PI05Policy(config)


In [None]:
import torch

batch_size = 1
device = "cuda"


def _mock_stats(shape, stat_names):
    """Build placeholder normalization statistics for one feature.

    "mean", "min" and "q01" are all-zeros tensors of ``shape``; every other
    requested statistic is an all-ones tensor of ``shape``. Keys appear in the
    order given by ``stat_names``.
    """
    zero_filled = ("mean", "min", "q01")
    return {
        stat: torch.zeros(*shape) if stat in zero_filled else torch.ones(*shape)
        for stat in stat_names
    }


_VECTOR_STATS = ("mean", "std", "min", "max", "q01", "q99")

dataset_stats = {
    "observation.state": _mock_stats((14,), _VECTOR_STATS),
    "action": _mock_stats((7,), _VECTOR_STATS),
    # The image entry carries no "min"/"max" statistics.
    "observation.images.base_1_rgb": _mock_stats((3, 224, 224), ("mean", "std", "q01", "q99")),
}

# Pre/post-processing pipelines matching the mock config and statistics.
preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)

# Random batch; entries are created in this order so the RNG draws stay the same.
batch = {
    "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
    "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
    # rand (not randn) keeps the image values in [0, 1)
    "observation.images.base_1_rgb": torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device),
    "task": ["Pick up the object"] * batch_size,
}

In [None]:
# Run the mock batch through the preprocessing pipeline and display the result.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest
# of the session; the next cell reads this variable, so a rename has to touch both.
input = preprocessor(batch)
input

In [None]:
# Ask the policy for an action on the preprocessed mock batch, pass it through
# the postprocessor, and display the final output.
action = policy.select_action(input)
output = postprocessor(action)
output

# Xarm

In [None]:
import mujoco
from pathlib import Path
import gym_lite6
import gymnasium as gym
import gym_lite6.env, gym_lite6.scripted_policy, gym_lite6.pickup_task
import mediapy as media
import numpy as np


from importlib import reload

# Re-import the gym_lite6 submodules (same order as before) so that edits made
# on disk are picked up without restarting the kernel.
for submodule in (
    gym_lite6.env,
    gym_lite6.utils,
    gym_lite6.scripted_policy,
    gym_lite6.pickup_task,
):
    reload(submodule)

# Alternative task taking the same arguments:
# task = gym_lite6.pickup_task.GraspTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')
task = gym_lite6.pickup_task.GraspAndLiftTask(
    'gripper_left_finger', 'gripper_right_finger', 'box', 'floor'
)

# Environment options collected in one place for readability.
env_kwargs = dict(
    task=task,
    obs_type="pixels_state",
    max_episode_steps=500,
    visualization_width=320,
    visualization_height=240,
    render_fps=30,
    joint_noise_magnitude=0.1,
)
env = gym.make("UfactoryCubePickup-v0", **env_kwargs)

# Reset once and show the side-camera view as a sanity check.
observation, info = env.reset()
side_view = env.unwrapped.render(camera="side_cam")
media.show_image(side_view)


In [None]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

# Load the metadata of the recorded dataset and display its statistics, which
# are later handed to the pre/post-processor factory.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
dataset_path= "/media/ssd/eugene/robotic_manipulation/lerobot_tests/datasets/lite6_record_scripted_250622"
dataset_meta = LeRobotDatasetMetadata(dataset_path)
dataset_meta.stats

In [None]:
# Display the feature description stored in the dataset metadata.
dataset_meta.features

In [None]:
from lerobot.configs.types import FeatureType, PolicyFeature

# PI0.5 configuration for the Lite6 arm: 7-dim state in, 7-dim action out.
config = PI05Config(max_action_dim=7, max_state_dim=7, dtype="bfloat16", device="cuda")

# Image features are declared channel-first (C, H, W). That is the layout
# `numpy_to_torch_obs` produces via `permute((2, 0, 1))`, and the layout the
# mock config earlier in this notebook uses (3, 224, 224).
config.input_features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)),
    "observation.images.side": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 240, 320)),
    "observation.images.gripper": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 240, 320)),
}

config.output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}

# Pre/post-processing pipelines built from the recorded dataset's statistics.
preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_meta.stats)


In [None]:
# Reset the environment and keep the raw observation returned by `env.reset()`
# for conversion to tensors in the following cells.
numpy_observation, info = env.reset()


In [None]:
def numpy_to_torch_obs(numpy_observation, device=None):
    """Convert a raw environment observation into a batched tensor dict.

    Args:
        numpy_observation: Nested dict with ``["state"]["qpos"]``,
            ``["state"]["gripper"]``, ``["pixels"]["side"]`` and
            ``["pixels"]["gripper"]`` entries holding numpy data.
        device: Target torch device. Defaults to the notebook-level
            ``config.device`` when omitted, which is what the function
            always used before this parameter existed.

    Returns:
        Dict with keys ``"observation.state"``, ``"observation.images.side"``
        and ``"observation.images.gripper"``, each with a leading batch
        dimension of size 1 and placed on ``device``.
    """
    if device is None:
        device = config.device

    def _image_to_tensor(pixels):
        # Move the last axis first (assumes H x W x C input — TODO confirm),
        # add the batch dimension, then divide by 255 (assumes 0-255 pixel
        # values, giving a float tensor in [0, 1]).
        return torch.from_numpy(pixels).permute((2, 0, 1)).unsqueeze(0).to(device) / 255

    # State vector = joint positions followed by the gripper value, as float32.
    state = np.float32(
        np.hstack((numpy_observation["state"]["qpos"], numpy_observation["state"]["gripper"]))
    )

    return {
        "observation.state": torch.from_numpy(state).unsqueeze(0).to(device),
        "observation.images.side": _image_to_tensor(numpy_observation["pixels"]["side"]),
        "observation.images.gripper": _image_to_tensor(numpy_observation["pixels"]["gripper"]),
    }
# Convert the latest raw observation, attach the language instruction under the
# "task" key, and display the resulting dict.
observation = numpy_to_torch_obs(numpy_observation)
observation["task"] = ["Pick up the red cube"]
observation


In [None]:
# Rebuild the policy with the Lite6 `config`; this rebinds `policy`, replacing
# the instance created in the mock-data section.
policy = PI05Policy(config)


In [None]:
# Single inference pass on the current environment observation.
observation = numpy_to_torch_obs(numpy_observation)
observation["task"] = ["Pick up the red cube"]
observation = preprocessor(observation)
# Keep the freshly selected action. Previously the result was discarded and the
# postprocessor ran on a stale `action` left over from the mock-data section.
action = policy.select_action(observation)
postprocessor(action)

In [None]:
# Roll out one episode with the policy in the simulated environment, recording
# frames, rewards and per-step data.
policy.reset()
policy.eval()
numpy_observation, info = env.reset()
rewards = []
frames = [numpy_observation["pixels"]["side"].squeeze()]
done = False
step = 0

ep_dict = {
    "action.qpos": [],
    "action.gripper": [],
    "observation.state.qpos": [],
    "observation.state.qvel": [],
    "observation.state.gripper": [],
    "observation.images.side": [],
    "observation.images.gripper": [],
    "reward": [],
    "timestamp": [],
    "frame_index": [],
}
while not done:
    observation = numpy_to_torch_obs(numpy_observation)
    observation["task"] = ["Pick up the red cube"]
    observation = preprocessor(observation)
    with torch.inference_mode():
        action = policy.select_action(observation)
        action = postprocessor(action)[0]
    # Leading `dof` entries drive the joints; the last entry is clipped to
    # [-1, 1] and rounded to an integer gripper command.
    action = {"qpos": action[:env.unwrapped.dof], "gripper": round(np.clip(action[-1].item(), -1, 1))}
    numpy_observation, reward, terminated, truncated, info = env.step(action)

    rewards.append(reward)
    frames.append(numpy_observation["pixels"]["side"].squeeze())

    ep_dict["action.qpos"].append(action["qpos"])
    ep_dict["action.gripper"].append(action["gripper"])
    ep_dict["observation.state.qpos"].append(numpy_observation["state"]["qpos"])
    ep_dict["observation.state.qvel"].append(numpy_observation["state"]["qvel"])
    ep_dict["observation.state.gripper"].append(numpy_observation["state"]["gripper"])
    ep_dict["observation.images.side"].append(numpy_observation["pixels"]["side"])
    ep_dict["observation.images.gripper"].append(numpy_observation["pixels"]["gripper"])
    ep_dict["reward"].append(reward)
    ep_dict["timestamp"].append(env.unwrapped.data.time)
    ep_dict["frame_index"].append(step)

    # Advance the frame counter; it was never incremented before, so every
    # recorded frame_index was 0.
    step += 1
    done = terminated or truncated

In [None]:
import mediapy as media

# Play back the side-camera frames collected during the rollout.
media.show_video(frames)

# Finetune

In [None]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.factory import make_dataset
from lerobot.policies.factory import make_policy

from pathlib import Path

# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
dataset_path = "/media/ssd/eugene/robotic_manipulation/lerobot_tests/datasets/lite6_record_scripted_250622"
# Path.name returns the final directory component even when the path ends with
# a trailing slash, where split('/')[-1] would give an empty string.
dataset_name = Path(dataset_path).name
dataset_meta = LeRobotDatasetMetadata(dataset_path)

In [None]:
from lerobot.configs.types import FeatureType, PolicyFeature

# PI0.5 configuration used for fine-tuning on the Lite6 dataset.
policy_config = PI05Config(
    max_action_dim=7,
    max_state_dim=7,
    dtype="bfloat16",
    device="cuda",
    scheduler_decay_steps=3000,
    compile_model=True,
    gradient_checkpointing=True,
    push_to_hub=False,
)

# Image features are declared channel-first (C, H, W), consistent with the mock
# config earlier in this notebook (3, 224, 224) and with the channel-first
# tensors produced by `numpy_to_torch_obs`.
policy_config.input_features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)),
    "observation.images.side": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 240, 320)),
    "observation.images.gripper": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 240, 320)),
}

policy_config.output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}

# Pre/post-processing pipelines built from the dataset statistics.
preprocessor, postprocessor = make_pi05_pre_post_processors(config=policy_config, dataset_stats=dataset_meta.stats)


In [None]:
import datetime
from pathlib import Path

# Output location: outputs/<dataset name>/<date>/<time>_<job name>
job_name = "pi05_training"
now = datetime.datetime.now()
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{job_name}"
output_dir = Path("outputs", dataset_name, train_dir)

# Training pipeline settings: one sample per batch, 3000 optimisation steps.
dataset_config = DatasetConfig(dataset_path)
cfg = TrainPipelineConfig(
    dataset=dataset_config,
    policy=policy_config,
    batch_size=1,
    steps=3000,
    job_name=job_name,
    num_workers=2,
    output_dir=output_dir,
)
cfg.validate()


In [None]:
# IMPORTANT: build the dataset through make_dataset (rather than instantiating
# LeRobotDataset directly) so that action_delta_timestamps are handled automatically.

dataset = make_dataset(cfg)

In [None]:
# Shuffled DataLoader over the training dataset; worker count and batch size
# come from the training config.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=cfg.batch_size,
    shuffle=True, # TODO: consider an episode-aware sampler instead of plain shuffling
    sampler=None,
    pin_memory=False, # not needed on unified memory; alternative: cfg.policy.device == "cuda"
    drop_last=False,
    # prefetch_factor may only be set when worker processes are used
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)

In [None]:
# Build the policy to fine-tune from the policy config and the dataset metadata.
policy = make_policy(cfg.policy, dataset.meta)


In [None]:
from lerobot.optim.factory import make_optimizer_and_scheduler
# Create the optimizer and learning-rate scheduler for the policy from the
# training config, then display the optimizer.
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
optimizer

In [None]:
from lerobot.datasets.utils import cycle
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker

# Number of policy updates performed so far (forward + backward + optim).
step = 0
# Endless iterator over the dataloader.
dl_iter = cycle(dataloader)

# (tracker key, display name, format spec) for every metric that is tracked.
_meter_specs = (
    ("loss", "loss", ":.3f"),
    ("grad_norm", "grdn", ":.3f"),
    ("lr", "lr", ":0.1e"),
    ("update_s", "updt_s", ":.3f"),
    ("dataloading_s", "data_s", ":.3f"),
)
train_metrics = {key: AverageMeter(label, fmt) for key, label, fmt in _meter_specs}

train_tracker = MetricsTracker(
    cfg.batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=None,
)

In [None]:
import time
from typing import Any
from torch.optim import Optimizer
import logging
from lerobot.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)

def update_policy(
    train_metrics: "MetricsTracker",
    policy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    lr_scheduler,
    lock=None,
) -> "tuple[MetricsTracker, dict]":
    """Run one optimisation step of ``policy`` on ``batch``.

    Performs the forward pass, back-propagation, gradient-norm computation
    (with optional clipping), the optimizer step and the scheduler step, then
    records loss, gradient norm, learning rate and elapsed time on
    ``train_metrics``.

    Args:
        train_metrics: Object receiving the ``loss``, ``grad_norm``, ``lr`` and
            ``update_s`` attributes.
        policy: Module whose ``forward(batch)`` returns ``(loss, output_dict)``.
        batch: Preprocessed training batch passed straight to ``policy.forward``.
        optimizer: Optimizer holding the policy parameters.
        grad_clip_norm: Maximum gradient norm. Values ``<= 0`` disable clipping;
            the norm is still computed so it can be reported.
        lr_scheduler: Scheduler stepped once per call, or ``None``.
        lock: Optional context manager held while ``optimizer.step()`` runs.

    Returns:
        ``(train_metrics, output_dict)`` where ``output_dict`` is the second
        value returned by ``policy.forward``.
    """
    start_time = time.perf_counter()
    policy.train()

    loss, output_dict = policy.forward(batch)
    loss.backward()

    # With an infinite max norm nothing is clipped, but the total gradient norm
    # is still returned for logging.
    max_norm = grad_clip_norm if grad_clip_norm > 0 else float("inf")
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(), max_norm, error_if_nonfinite=False
    )

    if lock is not None:
        with lock:
            optimizer.step()
    else:
        optimizer.step()
    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict

def create_state(batch):
    """Add an ``"observation.state"`` entry to ``batch`` in place.

    The new entry is ``batch["observation.state.qpos"]`` with
    ``batch["observation.state.gripper"]`` appended as one extra column along
    dimension 1. Nothing is returned.
    """
    joint_positions = batch["observation.state.qpos"]
    # unsqueeze(1) turns the gripper values into a column so they can be concatenated.
    gripper_column = batch["observation.state.gripper"].unsqueeze(1)
    batch["observation.state"] = torch.cat([joint_positions, gripper_column], dim=1)

# Main fine-tuning loop: resumes from the current `step` and runs until
# `cfg.steps` policy updates have been performed.
for _ in range(step, cfg.steps):
    # --- data loading (timed separately from the update) ---
    load_started = time.perf_counter()
    batch = next(dl_iter)
    create_state(batch)  # adds "observation.state" to the batch in place
    batch = preprocessor(batch)
    train_tracker.dataloading_s = time.perf_counter() - load_started

    # --- policy update; 0 is passed as grad_clip_norm ---
    train_tracker, output_dict = update_policy(
        train_tracker, policy, batch, optimizer, 0, lr_scheduler=lr_scheduler
    )
    step += 1
    train_tracker.step()
    logging.info(train_tracker)

    # --- periodic bookkeeping flags ---
    is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
    is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
    is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

    if is_log_step:
        logging.info(train_tracker)
        train_tracker.reset_averages()

    if cfg.save_checkpoint and is_saving_step:
        logging.info(f"Checkpoint policy after step {step}")
        checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
        save_checkpoint(
            checkpoint_dir=checkpoint_dir,
            step=step,
            cfg=cfg,
            policy=policy,
            optimizer=optimizer,
            scheduler=lr_scheduler,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
        )
        update_last_checkpoint(checkpoint_dir)


        # if cfg.env and is_eval_step:
        #     step_id = get_step_identifier(step, cfg.steps)
        #     logging.info(f"Eval policy at step {step}")
        #     with torch.no_grad():
        #         eval_info = eval_policy_all(
        #             envs=eval_env,  # dict[suite][task_id] -> vec_env
        #             policy=accelerator.unwrap_model(policy),
        #             preprocessor=preprocessor,
        #             postprocessor=postprocessor,
        #             n_episodes=cfg.eval.n_episodes,
        #             videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
        #             max_episodes_rendered=4,
        #             start_seed=cfg.seed,
        #             max_parallel_tasks=cfg.env.max_parallel_tasks,
        #         )
        #     # overall metrics (suite-agnostic)
        #     aggregated = eval_info["overall"]

        #     # optional: per-suite logging
        #     for suite, suite_info in eval_info.items():
        #         logging.info("Suite %s aggregated: %s", suite, suite_info)

        #     # meters/tracker
        #     eval_metrics = {
        #         "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
        #         "pc_success": AverageMeter("success", ":.1f"),
        #         "eval_s": AverageMeter("eval_s", ":.3f"),
        #     }
        #     eval_tracker = MetricsTracker(
        #         cfg.batch_size,
        #         dataset.num_frames,
        #         dataset.num_episodes,
        #         eval_metrics,
        #         initial_step=step,
        #     )
        #     eval_tracker.eval_s = aggregated.pop("eval_s")
        #     eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
        #     eval_tracker.pc_success = aggregated.pop("pc_success")



In [None]:
# Display the output directory configured for this training run.
cfg.output_dir