In [1]:
import os

# These must be set BEFORE jax / openpi are imported in the next cell: CUDA reads
# CUDA_VISIBLE_DEVICES when jax initializes its GPU backend, so changing it later
# has no effect in this kernel.
# `setdefault` keeps any value already exported in the shell, so other machines or
# users can override these without editing the notebook. The defaults below are
# machine-specific — adjust them for your setup.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
# NOTE(review): presumably the root where openpi's `download.maybe_download` caches
# checkpoints (the first run below fetches ~10 GB) — confirm against openpi docs.
os.environ.setdefault("OPENPI_DATA_HOME", "/data1/hogun/openpi")

In [2]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

# Policy inference

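The following example shows how to create a policy from a checkpoint and run inference on a dummy DROID-style input. Note that the first run downloads the checkpoint (about 10 GB) from `gs://openpi-assets`.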

In [4]:
# Look up the pi0-FAST DROID training config and fetch its checkpoint from GCS
# (presumably reused from the local cache when already downloaded).
config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download(
    "gs://openpi-assets/checkpoints/pi0_fast_droid"
)

# Build an inference-ready policy from the trained checkpoint.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run the policy on a synthetic input shaped like what the DROID runtime produces.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Inference is done; drop the policy reference so its memory can be released.
del policy

print("Actions shape:", result["actions"].shape)

  0%|          | 0.00/10.1G [00:00<?, ?iB/s]

ERROR:asyncio:Future exception was never retrieved
future: <Future finished exception=ClientConnectionError('Connection lost: [Errno 104] Connection reset by peer')>
Traceback (most recent call last):
  File "/home/hogunkee/.local/share/uv/python/cpython-3.11.13-linux-x86_64-gnu/lib/python3.11/asyncio/selector_events.py", line 974, in _read_ready__get_buffer
    nbytes = self._sock.recv_into(buf)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
ConnectionResetError: [Errno 104] Connection reset by peer

The above exception was the direct cause of the following exception:

aiohttp.client_exceptions.ClientConnectionError: Connection lost: [Errno 104] Connection reset by peer


Actions shape: (10, 8)


In [10]:
# Show which inputs the dummy DROID example provides (the keys of the input dict).
example.keys()

dict_keys(['observation/exterior_image_1_left', 'observation/wrist_image_left', 'observation/joint_position', 'observation/gripper_position', 'prompt'])

In [11]:
# Value stored under the gripper-position key of the dummy example.
example["observation/gripper_position"]

array([0.49965558])

In [12]:
# Joint-position vector of the dummy example (the output below shows 7 entries).
example["observation/joint_position"]

array([0.9791981 , 0.31966934, 0.30106429, 0.36159209, 0.85460669,
       0.15137372, 0.29608146])

In [13]:
# Shape of the wrist image in the dummy example.
example["observation/wrist_image_left"].shape

(224, 224, 3)

# Working with a live model


The following example shows how to create a live model from a checkpoint and compute the training loss. We first demonstrate this with fake data.


In [None]:
config = _config.get_config("pi0_aloha_sim")

checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_aloha_sim")
# Fixed PRNG key (seed 0) passed to `compute_loss`, so reruns use the same randomness.
# The real-data cell below reuses this same key.
key = jax.random.key(0)

# Create a model from the parameters stored under the checkpoint's "params" directory.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Compute the training loss on the fake batch (this does not sample actions).
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

Next, we create a data loader and compute the loss on a real batch of training data.

In [None]:
# Reduce the batch size to reduce memory usage. `dataclasses.replace` returns a new
# config object; the previous one is not mutated, so re-running this cell is safe.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: To keep this example self-contained, we skip the normalization step, since it
# requires normalization statistics generated with `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch. `model` and `key` come from the
# previous (fake-data) cell, so run that cell first.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)